from django.utils.encoding import smart_str
from drf_yasg.errors import SwaggerGenerationError
from drf_yasg.inspectors import SwaggerAutoSchema
from drf_yasg.utils import merge_params, get_object_classes
from rest_framework.parsers import FileUploadParser
from rest_framework.request import is_form_media_type
from rest_framework.schemas import AutoSchema
from rest_framework.utils import formatting

from Wchime.settings import SWAGGER_SETTINGS


def get_consumes(parser_classes):
    """Return the media types handled by ``parser_classes``.

    ``FileUploadParser`` (and subclasses) are skipped. Unlike some variants of
    this helper, form media types are kept in the result.

    :param parser_classes: parser classes or instances, resolved to classes
        via ``get_object_classes``
    :return: list of ``media_type`` strings, in parser order
    """
    resolved_classes = get_object_classes(parser_classes)
    return [
        parser_cls.media_type
        for parser_cls in resolved_classes
        if not issubclass(parser_cls, FileUploadParser)
    ]


def get_summary(string):
    """Return the first line of ``string`` after stripping it and removing every space.

    :param string: text to summarise, or None
    :return: the first ``\\n``-separated line of the compacted text, or None
        when ``string`` is None
    """
    if string is None:
        return None
    compacted = string.strip().replace(" ", "")
    first_line, _separator, _rest = compacted.partition("\n")
    return first_line


class CustomAutoSchema(AutoSchema):
    """DRF schema inspector whose "description" is the view's ``tags`` section.

    ``CustomSwaggerAutoSchema.get_tags`` instantiates this class to derive an
    operation tag from the view's own description text.
    """

    def get_description(self, path, method):
        """Return the ``tags`` section extracted from the view description.

        ``path`` and ``method`` are accepted to match the parent signature but
        are not used.
        """
        target_view = self.view
        view_text = target_view.get_view_description()
        return self._get_description_section(target_view, 'tags', view_text)


class CustomSwaggerAutoSchema(SwaggerAutoSchema):
    """drf_yasg auto-schema with project-specific tags, summaries and form parameters."""

    def get_tags(self, operation_keys=None):
        """Return the tag list for the current operation.

        Starting from drf_yasg's default tags:

        * if the list contains ``"api"`` and ``operation_keys`` is given, the first
          tag becomes ``operation_keys[SWAGGER_SETTINGS['AUTO_SCHEMA_TYPE']]``
          (index 2 when the setting is absent). It is left unchanged when there
          are too few keys.
        * if ``CustomAutoSchema`` extracts a non-empty description from the view,
          that text becomes the first tag.

        :param operation_keys: e.g. ``['v1', 'prize_join_log', 'create']``
        :return: a new list of tags (the parent's list is never modified)
        """
        # Work on a copy: the parent may return the very list passed as
        # ``swagger_auto_schema(tags=...)``, and assigning into it would silently
        # rewrite the decorator's arguments.
        tags = list(super().get_tags(operation_keys))

        if "api" in tags and operation_keys:
            key_index = SWAGGER_SETTINGS.get('AUTO_SCHEMA_TYPE', 2)
            try:
                tags[0] = operation_keys[key_index]
            except IndexError:
                # Fewer keys than the configured index: keep the default tag
                # rather than aborting the whole schema generation.
                pass

        description_schema = CustomAutoSchema()
        description_schema.view = self.view
        view_tag = description_schema.get_description(self.path, 'get')
        if view_tag:
            if tags:
                tags[0] = view_tag
            else:
                tags.append(view_tag)
        return tags

    def get_summary_and_description(self):
        """Return ``(summary, description)`` for the operation.

        Explicit ``operation_description`` / ``operation_summary`` overrides take
        precedence. Otherwise the description comes from the schema inspector
        (stripped, carriage returns removed), and a summary is split off it when
        no summary override exists. If there is still no summary, the
        description is reused as the summary.
        """
        description = self.overrides.get('operation_description', None)
        summary = self.overrides.get('operation_summary', None)
        if description is None:
            description = self._sch.get_description(self.path, self.method) or ''
            description = description.strip().replace('\r', '')

            if description and summary is None:
                # The description came from the docstring: split summary from body.
                summary, description = self.split_summary_from_description(description)
        if summary is None:
            summary = description
        return summary, description

    def get_consumes_form(self):
        """Return the media types of this view's parsers, including form types.

        ``FileUploadParser`` is excluded.
        """
        return get_consumes(self.get_parser_classes())

    def add_manual_parameters(self, parameters):
        """Merge the ``manual_parameters`` override into the generated parameters.

        This is overridden so an endpoint can be documented with form parameters
        as well as JSON. When manual parameters are supplied, they *replace* the
        automatically generated ones (including any generated body parameter)
        instead of being merged into them.

        :param parameters: automatically generated ``openapi.Parameter`` list
        :return: merged parameter list
        :raises SwaggerGenerationError: if a manual parameter is ``IN_BODY``, or if
            form parameters are used on a view without a form media type or on
            a method outside ``body_methods``
        """
        manual_parameters = self.overrides.get('manual_parameters', None) or []

        if manual_parameters:
            # Discard generated params so a generated body parameter does not
            # trigger the "form + body" error below.
            parameters = []

        # ``openapi`` is imported further down this module; it is resolved at call time.
        if any(param.in_ == openapi.IN_BODY for param in manual_parameters):  # pragma: no cover
            raise SwaggerGenerationError("specify the body parameter as a Schema or Serializer in request_body")
        if any(param.in_ == openapi.IN_FORM for param in manual_parameters):  # pragma: no cover
            has_body_parameter = any(param.in_ == openapi.IN_BODY for param in parameters)

            if has_body_parameter or not any(is_form_media_type(encoding) for encoding in self.get_consumes_form()):
                raise SwaggerGenerationError("cannot add form parameters when the request has a request body; "
                                             "did you forget to set an appropriate parser class on the view?")
            if self.method not in self.body_methods:
                raise SwaggerGenerationError("form parameters can only be applied to "
                                             "(" + ','.join(self.body_methods) + ") HTTP methods")

        return merge_params(parameters, manual_parameters)


# --------------------------------------------------------------------------------------------------------------

from rest_framework import serializers
from drf_yasg import openapi
from rest_framework.relations import PrimaryKeyRelatedField
from rest_framework.fields import ChoiceField


def serializer_to_swagger(ser_model, get_req=False):
    """Convert a serializer class into a dict of ``openapi.Schema`` properties.

    :param ser_model: serializer class (instantiated with no arguments), or None
    :param get_req: if true, describe a request body. ``SerializerMethodField``s
        and fields declared with an explicit ``source`` are skipped, and the
        names of fields whose ``required`` is not False are also collected.
    :return: ``{field_name: openapi.Schema}`` when ``get_req`` is false, otherwise
        ``({field_name: openapi.Schema}, [required_field_names])``. A None
        ``ser_model`` yields ``{}`` or ``({}, [])`` respectively.
    """
    if ser_model is None:
        # Previously only exactly True/False were handled; any other get_req
        # value fell through to ``None()`` and crashed.
        return ({}, []) if get_req else {}

    # Exact-type lookup: field classes not listed here map to TYPE_STRING.
    serializer_field_mapping = {
        ChoiceField: openapi.TYPE_INTEGER,
        PrimaryKeyRelatedField: openapi.TYPE_INTEGER,
        serializers.IntegerField: openapi.TYPE_INTEGER,
        serializers.BooleanField: openapi.TYPE_BOOLEAN,
        serializers.CharField: openapi.TYPE_STRING,
        serializers.DateField: openapi.TYPE_STRING,
        serializers.DateTimeField: openapi.TYPE_STRING,
        serializers.DecimalField: openapi.TYPE_NUMBER,
        serializers.DurationField: openapi.TYPE_STRING,
        serializers.EmailField: openapi.TYPE_STRING,
        serializers.ModelField: openapi.TYPE_OBJECT,
        serializers.FileField: openapi.TYPE_STRING,
        serializers.FloatField: openapi.TYPE_NUMBER,
        serializers.ImageField: openapi.TYPE_STRING,
        serializers.SlugField: openapi.TYPE_STRING,
        serializers.TimeField: openapi.TYPE_STRING,
        serializers.URLField: openapi.TYPE_STRING,
        serializers.UUIDField: openapi.TYPE_STRING,
        serializers.IPAddressField: openapi.TYPE_STRING,
        serializers.FilePathField: openapi.TYPE_STRING,
    }

    def _schema_for(field):
        # ``label`` may be None for fields without an explicit label. Only the
        # ChoiceField branch concatenates onto it, so substitute '' there to
        # avoid ``None + str`` raising TypeError. Other fields keep the raw
        # value (None stays None).
        description = getattr(field, 'label', '')
        if isinstance(field, ChoiceField):
            description = (description or '') + str(dict(getattr(field, 'choices', {})))
        typ = serializer_field_mapping.get(type(field), openapi.TYPE_STRING)
        return openapi.Schema(description=description, type=typ)

    properties = {}
    fields = ser_model().get_fields()

    if get_req:
        required = []
        for name, field in fields.items():
            # Method fields are output-only. Fields with an explicit ``source``
            # are presumably derived/display fields; neither is sent in requests.
            if isinstance(field, serializers.SerializerMethodField) or getattr(field, 'source', None):
                continue
            if getattr(field, 'required', True) is not False:
                required.append(name)
            properties[name] = _schema_for(field)
        return properties, required

    for name, field in fields.items():
        if isinstance(field, serializers.SerializerMethodField):
            continue
        properties[name] = _schema_for(field)
    return properties


def serializer_to_req_form_swagger(ser_model, filter_fields=()):
    """Build form (``in_=IN_FORM``) ``openapi.Parameter`` objects from a serializer class.

    ``SerializerMethodField``s, fields declared with an explicit ``source``, and
    fields named in ``filter_fields`` are skipped. File and image fields are
    typed as ``TYPE_FILE`` so Swagger UI renders an upload control.

    :param ser_model: serializer class (instantiated with no arguments), or None
    :param filter_fields: field names to leave out (any container supporting
        ``in``); None is treated as empty
    :return: list of ``openapi.Parameter``; empty when ``ser_model`` is None
    """
    if ser_model is None:
        return []
    excluded = filter_fields or ()

    # Exact-type lookup: field classes not listed here map to TYPE_STRING.
    serializer_field_mapping = {
        ChoiceField: openapi.TYPE_INTEGER,
        PrimaryKeyRelatedField: openapi.TYPE_INTEGER,
        serializers.IntegerField: openapi.TYPE_INTEGER,
        serializers.BooleanField: openapi.TYPE_BOOLEAN,
        serializers.CharField: openapi.TYPE_STRING,
        serializers.DateField: openapi.TYPE_STRING,
        serializers.DateTimeField: openapi.TYPE_STRING,
        serializers.DecimalField: openapi.TYPE_NUMBER,
        serializers.DurationField: openapi.TYPE_STRING,
        serializers.EmailField: openapi.TYPE_STRING,
        serializers.ModelField: openapi.TYPE_OBJECT,
        serializers.FileField: openapi.TYPE_FILE,
        serializers.FloatField: openapi.TYPE_NUMBER,
        serializers.ImageField: openapi.TYPE_FILE,
        serializers.SlugField: openapi.TYPE_STRING,
        serializers.TimeField: openapi.TYPE_STRING,
        serializers.URLField: openapi.TYPE_STRING,
        serializers.UUIDField: openapi.TYPE_STRING,
        serializers.IPAddressField: openapi.TYPE_STRING,
        serializers.FilePathField: openapi.TYPE_STRING,
    }

    parameters = []
    for name, field in ser_model().get_fields().items():
        if name in excluded:
            continue
        if isinstance(field, serializers.SerializerMethodField) or getattr(field, 'source', None):
            continue
        # ``label`` may be None; only the ChoiceField branch concatenates onto it,
        # so substitute '' there to avoid ``None + str`` raising TypeError.
        description = getattr(field, 'label', '')
        if isinstance(field, ChoiceField):
            description = (description or '') + str(dict(getattr(field, 'choices', {})))
        required = getattr(field, 'required', True)
        typ = serializer_field_mapping.get(type(field), openapi.TYPE_STRING)
        parameters.append(openapi.Parameter(name=name, description=description, type=typ,
                                            required=required, in_=openapi.IN_FORM))
    return parameters


class ViewSwagger(object):
    """Declarative holder for per-HTTP-method Swagger arguments.

    Subclasses override the ``<method>_*`` class attributes. ``get()``,
    ``post()``, ``put()`` and ``delete()`` each return a dict with the keys
    ``manual_parameters``, ``request_body`` and ``responses``. This dict is
    presumably splatted into ``swagger_auto_schema(**...)``; confirm against
    the callers.

    ``responses`` is None unless ``<method>_res_data`` is truthy. In that case
    it maps ``<method>_res_code`` to an ``openapi.Response``.

    NOTE(review): ``<method>_tags`` and ``<method>_operation_description`` are
    declared below but are not included in the returned dicts.
    """

    # NOTE: the list/dict defaults below are shared by every subclass that does
    # not override them; never mutate them in place.
    get_req_params = []
    get_req_body = None
    get_res_data = None
    get_res_examples = {'json': {}}
    get_res_description = ' '
    get_res_code = 200
    get_tags = None
    get_operation_description = None

    post_req_params = []
    post_req_body = None
    post_res_data = None
    post_res_examples = {'json': {}}
    post_res_description = ' '
    post_res_code = 200
    post_tags = None
    post_operation_description = None

    put_req_params = []
    put_req_body = None
    put_res_data = None
    put_res_examples = {'json': {}}
    put_res_description = ' '
    put_res_code = 200
    put_tags = None
    put_operation_description = None

    delete_req_params = []
    delete_req_body = None
    delete_res_data = None
    delete_res_examples = {'json': {}}
    delete_res_description = ' '
    delete_res_code = 200
    delete_tags = None
    delete_operation_description = None

    @classmethod
    def req_serialize_schema(cls, serializer):
        """Return ``(properties, required_names)`` describing a request serializer."""
        return serializer_to_swagger(serializer, get_req=True)

    @classmethod
    def res_serializer_schema(cls, serializer):
        """Return ``{name: openapi.Schema}`` describing a response serializer."""
        return serializer_to_swagger(serializer, get_req=False)

    @classmethod
    def req_serializer_form_schema(cls, serializer, filter_fields=()):
        """Return form ``openapi.Parameter``s for ``serializer``, minus ``filter_fields``.

        The default is an immutable tuple to avoid a shared mutable default.
        """
        return serializer_to_req_form_swagger(serializer, filter_fields)

    @classmethod
    def _operation_kwargs(cls, method):
        """Collect the ``<method>_*`` class attributes into one kwargs dict.

        :param method: attribute prefix, e.g. ``'get'``
        """
        def attr(suffix):
            return getattr(cls, method + '_' + suffix)

        res_data = attr('res_data')
        responses = None
        if res_data:
            responses = {
                attr('res_code'): openapi.Response(description=attr('res_description'),
                                                   schema=res_data,
                                                   examples=attr('res_examples')),
            }
        return {
            'manual_parameters': attr('req_params'),
            'request_body': attr('req_body'),
            'responses': responses,
        }

    @classmethod
    def get(cls):
        """Return the kwargs for the GET operation."""
        return cls._operation_kwargs('get')

    @classmethod
    def post(cls):
        """Return the kwargs for the POST operation."""
        return cls._operation_kwargs('post')

    @classmethod
    def put(cls):
        """Return the kwargs for the PUT operation."""
        return cls._operation_kwargs('put')

    @classmethod
    def delete(cls):
        """Return the kwargs for the DELETE operation."""
        return cls._operation_kwargs('delete')




